{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "e08c749d-b7bb-4f73-901a-78dd3c1ca078",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "from typing import List\n",
    "\n",
    "from langchain_core.pydantic_v1 import BaseModel, Field\n",
    "from langchain_openai import ChatOpenAI\n",
    "from langchain_community.graphs import Neo4jGraph\n",
    "\n",
    "\n",
    "os.environ[\"OPENAI_API_KEY\"] = \"sk-\"\n",
    "\n",
    "llm = ChatOpenAI(model_name=\"gpt-4-turbo\", request_timeout=180)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "64b2eac8-2aa1-4b8d-a5bc-fa8509838055",
   "metadata": {},
   "outputs": [],
   "source": [
    "query_types = {\n",
    "    \"Simple Retrieval Queries\": \"These queries focus on basic data extraction, retrieving nodes or relationships based on straightforward criteria such as labels, properties, or direct relationships. Examples include fetching all nodes labeled as 'Person' or retrieving relationships of a specific type like 'EMPLOYED_BY'. Simple retrieval is essential for initial data inspections and basic reporting tasks. Always limit the number of results if more than one row is expected from the questions by saying 'first 3' or 'top 5' elements\",\n",
    "    \"Complex Retrieval Queries\": \"These advanced queries use the rich pattern-matching capabilities of Cypher to handle multiple node types and relationship patterns. They involve sophisticated filtering conditions and logical operations to extract nuanced insights from interconnected data points. An example could be finding all 'Person' nodes who work in a 'Department' with over 50 employees and have at least one 'REPORTS_TO' relationship. Always limit the number of results if more than one row is expected from the questions by saying 'first 3' or 'top 5' elements\",\n",
    "    \"Simple Aggregation Queries\": \"Simple aggregation involves calculating basic statistical metrics over properties of nodes or relationships, such as counting the number of nodes, averaging property values, or determining maximum and minimum values. These queries summarize data characteristics and support quick analytical conclusions. Always limit the number of results if more than one row is expected from the questions by saying 'first 3' or 'top 5' elements\",\n",
    "    \"Pathfinding Queries\": \"Specialized in exploring connections between nodes, these queries are used to find the shortest path, identify all paths up to a certain length, or explore possible routes within a network. They are essential for applications in network analysis, routing, logistics, and social network exploration. Always limit the number of results if more than one row is expected from the questions by saying 'first 3' or 'top 5' elements\",\n",
    "    \"Complex Aggregation Queries\": \"The most sophisticated category, these queries involve multiple aggregation functions and often group results over complex subgraphs. They calculate metrics like average number of reports per manager or total sales volume through a network, supporting strategic decision making and advanced reporting. Always limit the number of results if more than one row is expected from the questions by saying 'first 3' or 'top 5' elements\",\n",
    "    \"Verbose query\": \"These queries are characterized by their explicit and detailed specifications about the data retrieval process and the exact information needed. They involve elaborate instructions for navigating through complex data structures, specifying precise criteria for inclusion, exclusion, and sorting of data points. Verbose queries typically require the breakdown of each step in the querying process, from the initial identification of relevant data nodes and relationships to the intricate filtering and sorting mechanisms that must be applied. Always limit the number of results if more than one row is expected from the questions by saying 'first 3' or 'top 5' elements\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "a1dc8cae-3f51-4417-bbf6-809b28dcc37e",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Question(BaseModel):\n",
    "    questions: List[str] = Field(\n",
    "        description=\"List of relevant questions for the particular graph schema. Make sure that questions can be answered with information from the schema and that the questions are diverse as possible. Make sure that the schema and the example values contains the information that can answer the questions! Do not ask questions that cannot be answered based on the provided schema. For example, if no information about subtitles can be found in the graph, don't ask any information about subtitles. Make sure to always limit the results to less than 10 results by saying 3 users, or top 5 movies, or similar.\"\n",
    "    )\n",
    "\n",
    "\n",
    "structured_llm = llm.with_structured_output(Question)\n",
    "\n",
    "from langchain_core.prompts import ChatPromptTemplate\n",
    "from langchain_core.prompts import (\n",
    "    HumanMessagePromptTemplate,\n",
    "    SystemMessagePromptTemplate,\n",
    ")\n",
    "\n",
    "\n",
    "system_prompt = \"\"\"Your task is to generate 100 questions that are directly related to a specific graph schema in Neo4j. Each question should target distinct aspects of the schema, such as relationships between nodes, properties of nodes, or characteristics of node types. Ensure that the questions vary in complexity, covering basic, intermediate, and advanced queries.\n",
    "\n",
    "Avoid ambiguous questions. For clarity, an ambiguous question is one that can be interpreted in multiple ways or does not have a straightforward answer based on the schema. For example, avoid asking, \"What is related to this?\" without specifying the node type or relationship.\n",
    "Please design each question to yield a limited number of results, specifically between 3 to 10 results. This will ensure that the queries are precise and suitable for detailed analysis and training.\n",
    "The goal of these questions is to create a dataset for training AI models to convert natural language queries into Cypher queries effectively.\n",
    "It is vital that the database contains information that can answer the question!\n",
    "Make sure to generate 100 questions!\"\"\"\n",
    "\n",
    "default_prompt = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        SystemMessagePromptTemplate.from_template(\n",
    "            f\"{system_prompt} Follow these instructions: {{instructions}}\"\n",
    "        ),\n",
    "        HumanMessagePromptTemplate.from_template(\n",
    "            \"Make sure to create questions for the following graph schema:{input}\\n Here are some example nodes and relationship values: {values}. Don't use any values that aren't found in the schema or in provided values. Also, do not ask questions that there is no way to answer based on the schema or provided example values. Don't include question index or the sequence of the question in the list.\"\n",
    "        ),\n",
    "    ]\n",
    ")\n",
    "\n",
    "chain = default_prompt | structured_llm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "6ebceb15-2e06-4b45-9581-613f37a655d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "DEMO_URL = \"neo4j+s://demo.neo4jlabs.com\"\n",
    "DEMO_DATABASES = [\n",
    "    \"recommendations\",\n",
    "    \"buzzoverflow\",\n",
    "    \"bluesky\",\n",
    "    \"companies\",\n",
    "    \"fincen\",\n",
    "    \"gameofthrones\",\n",
    "    \"grandstack\",\n",
    "    \"movies\",\n",
    "    \"neoflix\",\n",
    "    \"network\",\n",
    "    \"northwind\",\n",
    "    \"offshoreleaks\",\n",
    "    \"slack\",\n",
    "    \"stackoverflow2\",\n",
    "    \"twitch\",\n",
    "    \"twitter\",\n",
    "];"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d436d62d-02d0-4cb9-a19c-dc4f03fcdd41",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_questions = []\n",
    "for database in DEMO_DATABASES:\n",
    "    print(database)\n",
    "    graph = Neo4jGraph(\n",
    "        url=DEMO_URL,\n",
    "        database=database,\n",
    "        username=database,\n",
    "        password=database,\n",
    "        enhanced_schema=True,\n",
    "        sanitize=True,\n",
    "    )\n",
    "    schema = graph.schema\n",
    "    for type in query_types:\n",
    "        print(type)\n",
    "        instructions = f\"{type}: {query_types[type]}\"\n",
    "        values = graph.query(\n",
    "            \"\"\"MATCH (n)\n",
    "WHERE rand() > 0.6\n",
    "WITH n LIMIT 2\n",
    "CALL {\n",
    "    WITH n\n",
    "    MATCH p=(n)-[*3..3]-()\n",
    "    RETURN p LIMIT 1\n",
    "}\n",
    "RETURN p\"\"\"\n",
    "        )\n",
    "        questions = chain.invoke(\n",
    "            {\"input\": schema, \"instructions\": instructions, \"values\": values}\n",
    "        )\n",
    "        all_questions.extend(\n",
    "            [\n",
    "                {\"question\": el, \"type\": type, \"database\": database}\n",
    "                for el in questions.questions\n",
    "            ]\n",
    "        )\n",
    "        # Get values for each type separately"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "ecfc1ed8-a83e-4474-b78f-a30867680191",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9824"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(all_questions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "784a7ad5-f39a-4a4e-93b5-e3ffdf9a5d3f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'question': 'Identify the 3 answers with the lowest scores provided to the most viewed questions.', 'type': 'Simple Aggregation Queries', 'database': 'stackoverflow2'}, {'question': \"List the first 3 countries where the filings have originated from entities based in 'Angola'.\", 'type': 'Simple Retrieval Queries', 'database': 'fincen'}, {'question': \"List the top 5 users having the most interactions with a specific key in 'INTERACTED' relationship.\", 'type': 'Complex Retrieval Queries', 'database': 'bluesky'}, {'question': 'What are the user ids of those who have asked a question with a favorites count of zero?', 'type': 'Simple Retrieval Queries', 'database': 'buzzoverflow'}, {'question': 'Which 3 products have the lowest unit price and are still in stock?', 'type': 'Simple Retrieval Queries', 'database': 'northwind'}, {'question': \"List all questions tagged with 'node.js' and asked by users with a reputation below 5000.\", 'type': 'Simple Aggregation Queries', 'database': 'buzzoverflow'}, {'question': \"What are the addresses of the businesses categorized as 'Coffee'?\", 'type': 'Complex Aggregation Queries', 'database': 'grandstack'}, {'question': \"Who are the top 5 most followed users that 'Neo4j' amplifies?\", 'type': 'Complex Aggregation Queries', 'database': 'twitter'}, {'question': 'Identify the 3 most common ship addresses for orders.', 'type': 'Verbose query', 'database': 'northwind'}, {'question': 'What are the first 3 movies with a runtime less than 90 minutes and an imdbRating over 7?', 'type': 'Simple Aggregation Queries', 'database': 'recommendations'}]\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "# Get 10 random elements from the list\n",
    "random_elements = random.sample(all_questions, 10)\n",
    "\n",
    "print(random_elements)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "7a9432ad-1079-4e7f-aa06-068642c96065",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "df = pd.DataFrame.from_records(all_questions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "aa084985-7570-4edc-81f0-2a08ca67f58e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "database\n",
       "bluesky            135\n",
       "buzzoverflow       627\n",
       "companies          991\n",
       "fincen             617\n",
       "gameofthrones      399\n",
       "grandstack         828\n",
       "movies             767\n",
       "neoflix            937\n",
       "network            625\n",
       "northwind          818\n",
       "offshoreleaks      514\n",
       "recommendations    792\n",
       "slack              356\n",
       "stackoverflow2     313\n",
       "twitch             585\n",
       "twitter            520\n",
       "dtype: int64"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.groupby(\"database\").size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "c4b0d6da-fbc8-4629-b4ad-e8eb8781d90a",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.to_csv(\"text2cypher_questions.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "15d622db-cc1c-463b-a598-3ef6340ad374",
   "metadata": {},
   "outputs": [],
   "source": [
    "schemas = {}\n",
    "all_schemas = []\n",
    "for database in DEMO_DATABASES:\n",
    "    graph = Neo4jGraph(\n",
    "        url=DEMO_URL,\n",
    "        database=database,\n",
    "        username=database,\n",
    "        password=database,\n",
    "        enhanced_schema=True,\n",
    "        sanitize=True,\n",
    "    )\n",
    "    schema = graph.schema\n",
    "    schemas[database] = schema\n",
    "    all_schemas.append(\n",
    "        {\n",
    "            \"database\": database,\n",
    "            \"schema\": schema,\n",
    "            \"structured_schema\": graph.structured_schema,\n",
    "        }\n",
    "    )\n",
    "\n",
    "df_schemas = pd.DataFrame.from_records(all_schemas)\n",
    "df_schemas.to_csv(\"text2cypher_schemas.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "7dcbba6e-20ab-4d3a-a8cb-1b7f7c469c41",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create OpenAI Batch processing JSONL file\n",
    "\n",
    "system = \"\"\"Given an input question, convert it to a Cypher query. No pre-amble.\n",
    "Additional instructions:\n",
    "- Ensure that queries checking for non-null properties use `IS NOT NULL` in a straightforward manner.\n",
    "- Don't use `size((n)--(m))` for counting relationships. Instead use the new `count{(n)--(m))}` syntax.\n",
    "- Incorporate the new existential subqueries in examples where the query needs to check for the existence of a pattern.\n",
    "  Example: MATCH (p:Person)-[r:IS_FRIENDS_WITH]->(friend:Person)\n",
    "            WHERE exists{ (p)-[:WORKS_FOR]->(:Company {name: 'Neo4j'})}\n",
    "            RETURN p, r, friend\"\"\"\n",
    "\n",
    "\n",
    "def create_line(question, database, id):\n",
    "    schema = schemas[database]\n",
    "    user_message = f\"\"\"Based on the Neo4j graph schema below, write a Cypher query that would answer the user's question:\n",
    "    {schema}\n",
    "    \n",
    "    Question: {question}\n",
    "    Cypher query:\"\"\"\n",
    "\n",
    "    messages = [\n",
    "        {\"role\": \"system\", \"content\": system},\n",
    "        {\"role\": \"user\", \"content\": user_message},\n",
    "    ]\n",
    "    line = {\n",
    "        \"custom_id\": \"request-\" + str(id),\n",
    "        \"method\": \"POST\",\n",
    "        \"url\": \"/v1/chat/completions\",\n",
    "        \"body\": {\n",
    "            \"model\": \"gpt-4-turbo\",\n",
    "            \"messages\": messages,\n",
    "            \"max_tokens\": 1000,\n",
    "            \"temperature\": 0,\n",
    "        },\n",
    "    }\n",
    "    return line\n",
    "\n",
    "\n",
    "with open(\"text2cypher_batch_input.jsonl\", \"w\") as outfile:\n",
    "    for i, q in enumerate(all_questions):\n",
    "        line = create_line(q[\"question\"], q[\"database\"], i)\n",
    "        json.dump(line, outfile)\n",
    "        outfile.write(\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a14f6a6-977f-43cc-bec4-1df55c51e4f6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
